0%

Triton shared memory

在使用triton时,可以使用shared memory方式加速数据传输的过程

shared memory 与 cuda shared memory

  • shared memory模式下,triton会在数据存放于/dev/shm目录下,之后将文件路径通过rpc接口发送给triton server
  • cuda shared memory模式下,triton会在指定显卡上申请一块显存区域,之后直接将数据存放于显存中并通过rpc通知triton server

两种shm模式均会使用额外的显存或内存区域,并且依赖程序主动注销申请的空间,因此程序编写时需要特别注意资源的回收和释放,否则将会导致内存或显存泄露

考虑到一般情况下内存空间会远大于显存空间,在发生泄漏时内存空间也比显存空间更好回收,因此虽然cuda-shm模式的速度相比shm更快,但是我依旧更推荐使用shm模式。

shm memory pool

使用shm模式的最简单的方式是每次使用随机或顺序递增的文件名去申请内存,在使用结束后立刻释放,但是考虑到保持内存用量的稳定以及复用内存空间来进一步加速,可以使用shm内存池来更好的使用shm模式

内存池的代码如下,主要特性:

  1. 根据用户提供的shm_name_prefix, shm_key_prefix创建指定数量的共享内存对象,如果使用该名称的对象已经存在则原有对象删除重新创建,这一步保证了过去泄露的内存空间能够被回收
  2. 将申请的所有内存空间注册到triton server
  3. 支持asyncio,使用await语句等待空闲内存对象
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import tritonclient.grpc as grpcclient
import tritonclient.utils.shared_memory as shm
import asyncio
from asyncio import Queue
from typing import Awaitable
import logging
import os
from pathlib import Path

class ShmRegion(object):
def __init__(self, triton_client: grpcclient, shm_queue: Queue, max_data_size, shm_name, shm_key):
self.name = shm_name
self.key = shm_key
self.shm_queue = shm_queue
self.size = max_data_size
self.triton_client: grpcclient = triton_client
# 这里要注意,也许triton升级之后shm就不放这里了,这样清理过期资源不是长久之计
self.shm_path = Path(f"/dev/shm/{self.key}")
if self.shm_path.exists():
os.remove(self.shm_path)
self.triton_client.unregister_system_shared_memory(self.name)
self.handle = shm.create_shared_memory_region(self.name, self.key, max_data_size)
self.triton_client.register_system_shared_memory(self.name, self.key, max_data_size)
logging.info(f"shm region {self.name} registered")

def addToQueue(self, shm_queue=None):
if shm_queue is not None:
assert self.shm_queue is None or self.shm_queue is shm_queue
self.shm_queue = shm_queue
else:
assert self.shm_queue is not None
self.shm_queue.put_nowait(self)

def __enter__(self):
return self

def __exit__(self, type, value, trace):
self.addToQueue()

def __del__(self):
logging.info(f"shm region {self.name} removed")
self.triton_client.unregister_system_shared_memory(self.name)
shm.destroy_shared_memory_region(self.handle)


class ShmTritonClient(object):
def __init__(self, triton_client, max_queue_size, max_data_size, shm_name_prefix, shm_key_prefix):
self.triton_client = triton_client
# self.triton_client.unregister_system_shared_memory()
# 这似乎是一个异步请求,执行后会导致后面的注册失效
# 等到能拿到loop的时候再执行
# self.shm_queue = Queue(maxsize=max_queue_size, loop=None)
self.shm_queue = None
self.max_queue_size = max_queue_size
self.regions = []
self.registered_regions = 0
for i in range(max_queue_size):
region = ShmRegion(self.triton_client, self.shm_queue, max_data_size, f"{shm_name_prefix}_{i}", f"{shm_key_prefix}_{i}")
self.regions.append(region)
# 等到能拿到loop的时候再执行
# region.addToQueue()

def getRegion(self) -> Awaitable[ShmRegion] :
"""每次调用时如果还有未被注册的region, 则注册一个, """
if self.shm_queue is None:
loop = asyncio.get_running_loop()
self.shm_queue = Queue(maxsize=self.max_queue_size, loop=loop)
if len(self.regions) > self.registered_regions:
self.regions[self.registered_regions].addToQueue(self.shm_queue)
self.registered_regions += 1
return self.shm_queue.get()

使用例子如下:

  1. 初始化
1
2
3
4
5
6
7
import numpy as np
import tritonclient.grpc as grpcclient
triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
# 计算你的数据的最大内存使用,比如最大的batch=32,最大size=3*512*512
max_data = np.zeros(shape=(32, 3, 512, 512), dtype=np.float32)
byte_size = max_data.size * max_data.itemsize
shm_client = ShmTritonClient(triton_client, max_queue_size=16, max_data_size=byte_size, shm_name_prefix="shm_data", shm_key_prefix="/shm_data")
  1. 调用
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import tritonclient.utils as utils
import tritonclient.utils.shared_memory as shm
from await_infer import await_infer

async def infer(img, triton_client):
input_byte_size = img.size * img.itemsize
# 一次性获取2个region分别用于输入输出
with (await shm_client.getRegion()) as inpRegion, \
(await shm_client.getRegion()) as outRegion:
shm.set_shared_memory_region(inpRegion.handle, [img])
inputs = [grpcclient.InferInput('input', [*img.shape], "FP32")]
# 输入为FP32的图片,很大,需要使用shm加速
inputs[0].set_shared_memory(inpRegion.name, input_byte_size)
# 假设网络有2个输出,第一个输出比较大使用shm模式,第二个输出很小,直接使用grpc完成传输
outputs = [grpcclient.InferRequestedOutput(name) for name in ['shm_output''plain_output']]
outputs[0].set_shared_memory(outRegion.name, outRegion.size)
# await_inferd的代码参考文章 "Nvidia Triton 使用教程"
# https://maple.link/2021/06/10/Nvidia%20Triton%20Server%E7%9A%84%E4%BD%BF%E7%94%A8/
results = await await_infer(
triton_client = triton_client,
model_name = "model_name",
inputs = inputs,
outputs = outputs
)

shm_output = results.get_output("shm_output")
shm_output = shm.get_contents_as_numpy(
outRegion.handle, utils.triton_to_np_dtype(shm_output.datatype),
shm_output.shape)
plain_output = results.as_numpy('plain_output')
# copy前shm_output, plain_output都是只读的,无法编辑
shm_output = shm_output.copy()
plain_output = plain_output.copy()

return shm_output, plain_output